{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# In Google Colab, uncomment one of the install commands below.\n",
    "# For master version of catalyst, uncomment:\n",
    "# (master version should be fully compatible with this notebook)\n",
    "# ! pip install git+https://github.com/catalyst-team/catalyst.git\n",
    "\n",
    "# For last release version of catalyst, uncomment:\n",
    "# ! pip install catalyst\n",
    "\n",
    "# For specific commit version of catalyst, uncomment:\n",
    "# ! pip install git+http://github.com/catalyst-team/catalyst.git@{commit_hash}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Segmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you have Unet, all CV is segmentation now."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Goals\n",
    "\n",
    "- train Unet on the ISBI dataset\n",
    "- visualize the predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash \n",
    "\n",
    "# Get the data:\n",
    "# gdrive_download FILE_ID OUTPUT_PATH\n",
    "#   $1 - Google Drive file id, substituted into the download URL\n",
    "#   $2 - local path the downloaded file is written to\n",
    "function gdrive_download () {\n",
    "  # First request: keep the session cookies in /tmp/cookies.txt and send the\n",
    "  # response to stdout so sed can extract the token that follows `confirm=`.\n",
    "  # NOTE(review): --no-check-certificate turns off TLS certificate validation.\n",
    "  CONFIRM=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate \"https://docs.google.com/uc?export=download&id=$1\" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')\n",
    "  # Second request: reuse the saved cookies and pass the token to fetch the file.\n",
    "  wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$CONFIRM&id=$1\" -O $2\n",
    "  # Remove the temporary cookie jar.\n",
    "  rm -rf /tmp/cookies.txt\n",
    "}\n",
    "\n",
    "# Download the archive into ./data and unpack it there.\n",
    "mkdir -p data\n",
    "gdrive_download 1N82zh0kzmnzqRvUyMgVOGsCoS1kHf3RP ./data/isbi.tar.gz\n",
    "tar -xf ./data/isbi.tar.gz -C ./data/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Final folder structure with training data:\n",
    "```bash\n",
    "catalyst-examples/\n",
    "    data/\n",
    "        isbi/\n",
    "            train-volume.tif\n",
    "            train-labels.tif\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Make only CUDA device 0 visible to this process.\n",
    "os.environ.update({\"CUDA_VISIBLE_DEVICES\": \"0\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ! pip install tifffile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tifffile as tiff\n",
    "\n",
    "# Read the image volume and the matching label volume from the unpacked archive.\n",
    "images = tiff.imread('./data/isbi/train-volume.tif')\n",
    "masks = tiff.imread('./data/isbi/train-labels.tif')\n",
    "\n",
    "# Pair every image with its mask; the last 4 pairs are held out for validation.\n",
    "data = list(zip(images, masks))\n",
    "train_data, valid_data = data[:-4], data[-4:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from catalyst.data import Augmentor\n",
    "from catalyst.dl import utils\n",
    "\n",
    "bs = 4\n",
    "num_workers = 4\n",
    "\n",
    "\n",
    "def _to_scaled_tensor(x):\n",
    "    \"\"\"Return ``x`` as a float32 tensor divided by 255 with a new leading axis.\"\"\"\n",
    "    scaled = x.copy().astype(np.float32) / 255.\n",
    "    return torch.from_numpy(scaled).unsqueeze_(0)\n",
    "\n",
    "\n",
    "def open_fn(x):\n",
    "    \"\"\"Map an ``(image, mask)`` pair to a dict with \"features\" and \"targets\" keys.\"\"\"\n",
    "    return {\"features\": x[0], \"targets\": x[1]}\n",
    "\n",
    "\n",
    "# Per-key steps: scale features, normalize features with mean 0.5 / std 0.5,\n",
    "# then scale targets the same way as features (no normalization for targets).\n",
    "data_transform = transforms.Compose([\n",
    "    Augmentor(dict_key=\"features\", augment_fn=_to_scaled_tensor),\n",
    "    Augmentor(\n",
    "        dict_key=\"features\",\n",
    "        augment_fn=transforms.Normalize((0.5, ), (0.5, ))),\n",
    "    Augmentor(dict_key=\"targets\", augment_fn=_to_scaled_tensor),\n",
    "])\n",
    "\n",
    "# Build one loader per split; only the training split is shuffled.\n",
    "loaders = collections.OrderedDict()\n",
    "for split_name, split_data, do_shuffle in (\n",
    "        (\"train\", train_data, True),\n",
    "        (\"valid\", valid_data, False)):\n",
    "    loaders[split_name] = utils.get_loader(\n",
    "        split_data,\n",
    "        open_fn=open_fn,\n",
    "        dict_transform=data_transform,\n",
    "        batch_size=bs,\n",
    "        num_workers=num_workers,\n",
    "        shuffle=do_shuffle)\n",
    "\n",
    "train_loader = loaders[\"train\"]\n",
    "valid_loader = loaders[\"valid\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from catalyst.contrib.models.cv import Unet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from catalyst.dl import SupervisedRunner\n",
    "\n",
    "# experiment setup\n",
    "num_epochs = 50\n",
    "logdir = \"./logs/segmentation_notebook\"\n",
    "\n",
    "# model, criterion, optimizer, scheduler\n",
    "model = Unet(num_classes=1, in_channels=1, num_channels=64, num_blocks=4)\n",
    "# BCEWithLogitsLoss applies the sigmoid itself, so the model outputs raw logits.\n",
    "criterion = nn.BCEWithLogitsLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "# Multiply the learning rate by 0.3 at epochs 10, 20 and 40.\n",
    "scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 40], gamma=0.3)\n",
    "\n",
    "\n",
    "# model runner\n",
    "runner = SupervisedRunner()\n",
    "\n",
    "# model training\n",
    "# The scheduler must be handed to the runner; otherwise it is never stepped and\n",
    "# the learning rate stays at its initial value for the whole run.\n",
    "runner.train(\n",
    "    model=model,\n",
    "    criterion=criterion,\n",
    "    optimizer=optimizer,\n",
    "    scheduler=scheduler,\n",
    "    loaders=loaders,\n",
    "    logdir=logdir,\n",
    "    num_epochs=num_epochs,\n",
    "    verbose=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on the validation loader, resuming from the best saved checkpoint.\n",
    "best_checkpoint = f\"{logdir}/checkpoints/best.pth\"\n",
    "runner_out = runner.predict_loader(\n",
    "    model,\n",
    "    loaders[\"valid\"],\n",
    "    resume=best_checkpoint,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predictions visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# Apply matplotlib's \"ggplot\" style sheet.\n",
    "plt.style.use(\"ggplot\")\n",
    "# IPython magic: render figures inline in the notebook output.\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(x):\n",
    "    \"\"\"Element-wise logistic function: 1 / (1 + exp(-x)).\"\"\"\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "\n",
    "# Probability cut-off for the binary mask; constant, so it is set once up front.\n",
    "threshold = 0.5\n",
    "\n",
    "# The loop variable is `sample`, not `input`: rebinding `input` at notebook scope\n",
    "# would shadow the builtin for every later cell in the kernel.\n",
    "for sample, logits in zip(valid_data, runner_out):\n",
    "    image, mask = sample\n",
    "\n",
    "    # logits[0] is presumably the single output channel (num_classes=1) - confirm.\n",
    "    predicted = (sigmoid(logits[0]) > threshold).astype(np.uint8)\n",
    "\n",
    "    # One row per validation sample: input, thresholded prediction, ground truth.\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(10, 8))\n",
    "    panels = ((\"image\", image), (\"prediction\", predicted), (\"ground truth\", mask))\n",
    "    for ax, (title, picture) in zip(axes, panels):\n",
    "        ax.imshow(picture, cmap='gray')\n",
    "        ax.set_title(title)\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}